library(panelView)
library(zoo)
library(synthdid)
library(fect)
library(ggplot2)
library(data.table)
library(parallel)
library(doParallel)

# Mirror everything printed to the console into a results file; the matching
# sink() call that closes the diversion is at the end of the script.
sink("power_results.txt", split=TRUE)

# Helper script; presumably defines StaggeredSynthDiD(), which is called
# further down (its contents are not visible from this file).
source("staggered_synth.R")

laws <- read.csv("statelaws.csv")
turnout <- read.csv("stateturnout.csv")

# Drop rows whose State field is blank.
laws <- laws[laws$State != "", ]

# Reshape the wide table (first column = State, every other column = one
# period) into long format with one row per state/column pair. `Law` is TRUE
# only where the cell is exactly "X"; NA and any other value become FALSE.
# This is done on the whole table at once instead of appending one
# single-row data frame per cell, which was quadratic in the number of cells
# and broke on an empty table because of `1:nrow(laws)`.
year_cols <- colnames(laws)[-1]
law_cells <- as.matrix(laws[, -1, drop = FALSE])
has_law <- !is.na(law_cells) & law_cells == "X"

# t() + as.vector() walks the cells state by state (all columns of the first
# state, then the second, ...), matching the row order of the nested loop.
laws <- data.frame(
  STATE = rep(laws$State, each = length(year_cols)),
  YEAR  = rep(year_cols, times = nrow(laws)),
  Law   = as.vector(t(has_law))
)

# Drop the first character of each former column name.
# NOTE(review): presumably this strips the "X" prefix that read.csv() adds to
# headers that start with a digit (e.g. "X2006" -> "2006") - confirm against
# the CSV header.
laws$YEAR <- substring(laws$YEAR, 2, 10000)

# Join the long-format law indicators to the turnout table on state and year.
data <- merge(laws, turnout, by = c("STATE", "YEAR"))

# YEAR is derived from column names upstream and is therefore character here.
# Convert it once so that every year filter below is a numeric comparison
# rather than a string comparison that only happens to work for 4-digit years.
data$YEAR <- as.integer(data$YEAR)

# Exclude Texas from the panel.
data <- data[!(data$STATE %in% c("Texas")), ]

# Keep years after 1982 with YEAR %% 4 == 2 (1986, 1990, ...). 2022 also
# satisfies the modulo rule, so it has to be excluded explicitly.
keep_year <- data$YEAR > 1982L & data$YEAR != 2022L & data$YEAR %% 4L == 2L
data <- data[keep_year, ]

# Outcome: votes for highest office as a percentage of the voting-eligible
# population (VEP). Thousands separators are removed before parsing.
data$Y <- as.numeric(gsub(",", "", data$VOTE_FOR_HIGHEST_OFFICE)) /
  as.numeric(gsub(",", "", data$VEP)) * 100

names(data)[names(data) == "Law"]   <- "W"
names(data)[names(data) == "STATE"] <- "State"
names(data)[names(data) == "YEAR"]  <- "Year"

# Treatment indicator as 0/1 instead of FALSE/TRUE.
data$W <- as.numeric(data$W)

# group = 1 for the first list of states, 2 for the second list, and 0 for
# every other state.
group1_states <- c("Arizona", "Georgia", "Indiana", "Ohio")
group2_states <- c(
  "Alabama", "Kansas", "Mississippi", "North Dakota",
  "Tennessee", "Virginia", "Wisconsin"
)
data$group <- as.numeric(data$State %in% group1_states) +
  2 * as.numeric(data$State %in% group2_states)

# The simulation only uses states that are in neither list (group 0).
data.sim <- data[data$group %in% 0, ]

# Leave one core free for the main session, but never ask for fewer than one
# worker: detectCores() can return NA when the count cannot be determined, and
# subtracting one gives 0 on a single-core machine; makeCluster() fails on
# either value.
n_cores <- max(1L, detectCores() - 1L, na.rm = TRUE)
cl <- makeCluster(n_cores)
registerDoParallel(cl)

#' Run one iteration of the placebo-adoption power simulation.
#'
#' Randomly assigns staggered adoption dates to some of `states`, lowers the
#' outcome of every treated state-year by a binomially drawn share of its
#' voting-eligible population, and estimates the effect with
#' `StaggeredSynthDiD()`.
#'
#' @param iter Iteration index. Also used as the RNG seed, so a given
#'   iteration always draws the same adopters and the same shocks.
#' @param data.sim Data frame with columns `State`, `Year`, `Y` (outcome in
#'   percent) and `VEP` (counts, optionally with thousands separators).
#' @param states Vector of state names from which adopters are drawn.
#' @param effect_size Probability used for the binomial draw on `VEP`; the
#'   expected drop in `Y` is `100 * effect_size` percentage points. Defaults
#'   to 0.03, the value that used to be hard-coded.
#' @return A list with `estimate`, `significant` (two-sided test at the 5%
#'   level), `p_value`, `se`, and `error` (NA on success, otherwise the error
#'   message of the failed fit, in which case the other fields are NA).
run_simulation <- function(iter, data.sim, states, effect_size = 0.03) {
  # source() evaluates in the global environment by default, so the helper
  # only needs to be loaded once per R process (i.e. once per worker), not on
  # every iteration.
  if (!exists("StaggeredSynthDiD", mode = "function")) {
    source("staggered_synth.R")
  }

  set.seed(iter)

  # Adoption cohorts: first treated year, number of adopters to draw, and the
  # group label stored for the cohort. Sizes noted for alternative designs:
  # 2012 and 2014 - 2 for presidential/all elections, 3 for midterms; 2016 - 1.
  cohorts <- data.frame(
    start = c(2006L, 2008L, 2012L, 2014L, 2016L),
    size  = c(3L, 1L, 0L, 0L, 0L),
    group = c(1L, 1L, 2L, 2L, 2L)
  )

  # Compare years numerically even if the Year column is stored as character.
  year <- as.integer(data.sim$Year)

  data.sim$W2 <- 0L
  data.sim$group2 <- 0L

  # Draw the cohorts in chronological order without replacement; the order of
  # the sample() calls is fixed so results are reproducible for a given seed.
  remaining <- states
  for (k in seq_len(nrow(cohorts))) {
    adopters  <- sample(remaining, cohorts$size[k])
    remaining <- setdiff(remaining, adopters)

    in_cohort <- data.sim$State %in% adopters
    data.sim$W2[in_cohort & year >= cohorts$start[k]] <- 1L
    data.sim$group2[in_cohort] <- cohorts$group[k]
  }

  treated <- data.sim$W2 == 1
  vep <- as.numeric(gsub(",", "", data.sim$VEP))

  # Treated state-years lose Binomial(VEP, effect_size) voters, expressed in
  # percentage points of VEP; untreated rows keep the observed outcome.
  lost_votes <- rbinom(
    n    = sum(treated),
    size = vep[treated],
    prob = effect_size
  )
  data.sim$Y2 <- data.sim$Y
  data.sim$Y2[treated] <- data.sim$Y[treated] - lost_votes / vep[treated] * 100

  dt <- data.frame(
    Unit  = data.sim$State,
    Time  = year,
    Y     = data.sim$Y2,
    W     = data.sim$W2,
    group = data.sim$group2
  )

  # A single failed fit should not abort the whole batch: the caller already
  # filters out NA estimates. The error message is kept in the result so that
  # failures remain visible even when warnings from workers are lost.
  # NOTE(review): vcov = "jack" presumably selects jackknife standard errors -
  # confirm in staggered_synth.R.
  fit_error <- NA_character_
  fit <- tryCatch(
    StaggeredSynthDiD(
      data       = dt,
      unit       = "Unit",
      time       = "Time",
      outcome    = "Y",
      treatment  = "W",
      vcov       = "jack"
    ),
    error = function(e) {
      fit_error <<- conditionMessage(e)
      warning(
        "Simulation ", iter, " failed: ", conditionMessage(e),
        call. = FALSE
      )
      NULL
    }
  )

  if (is.null(fit)) {
    return(list(
      estimate    = NA_real_,
      significant = NA,
      p_value     = NA_real_,
      se          = NA_real_,
      error       = fit_error
    ))
  }

  # Two-sided normal p-value; pnorm(-|z|) keeps precision in the tails where
  # 1 - pnorm(|z|) rounds to zero.
  p_val <- 2 * pnorm(-abs(fit$Estimate / fit$SE))

  list(
    estimate    = fit$Estimate,
    significant = p_val < 0.05,
    p_value     = p_val,
    se          = fit$SE,
    error       = fit_error
  )
}


# Pool of states from which each iteration draws its adopters.
states <- unique(data.sim$State)
niter <- 1000

cat("Starting parallel simulation with", n_cores, "cores...\n")
start.time <- Sys.time()

# Run all iterations on the cluster. stopCluster() sits in `finally` so the
# worker processes are shut down even when an iteration raises an error.
results <- tryCatch(
  {
    clusterExport(cl, c("data.sim", "states"))
    parLapply(
      cl, seq_len(niter), run_simulation,
      data.sim = data.sim, states = states
    )
  },
  finally = stopCluster(cl)
)

time.taken <- Sys.time() - start.time

# Each result is expected to hold one value per field; vapply() enforces that
# instead of silently returning a list or matrix as sapply() would.
att_est    <- vapply(results, function(x) x$estimate, numeric(1))
att_sig    <- vapply(results, function(x) x$significant, logical(1))
p_values   <- vapply(results, function(x) x$p_value, numeric(1))
std_errors <- vapply(results, function(x) x$se, numeric(1))

# An iteration counts as valid only when it has both an estimate and a test
# decision. Requiring the latter keeps the summaries below well defined when
# a standard error is missing (otherwise sum(att_sig) would be NA and the
# `if` below would fail).
valid_results <- !is.na(att_est) & !is.na(att_sig)
att_est    <- att_est[valid_results]
att_sig    <- att_sig[valid_results]
p_values   <- p_values[valid_results]
std_errors <- std_errors[valid_results]

cat("Parallel processing completed in", round(time.taken, 2), units(time.taken), "\n")
cat("Valid simulations:", length(att_est), "of", niter, "\n")
cat("Significant at 5%:", sum(att_sig), "of", length(att_est), "\n")
cat("Proportion significant:", round(mean(att_sig), 3), "\n")
if (any(att_sig)) {
  cat("Mean ATT (signif.):", round(mean(att_est[att_sig]), 4), "\n")
}
cat("Mean ATT (all):", round(mean(att_est), 4), "\n")
cat("SD ATT:", round(sd(att_est), 4), "\n")

# Close the output diversion opened at the top of the script.
sink()